{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import concurrent.futures\n",
    "from multiprocessing import shared_memory\n",
    "import sys\n",
    "from typing import Dict, Tuple\n",
    "\n",
    "import matplotlib\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import xarray as xr\n",
    "\n",
    "climate_indices_home_path = \"/home/james/git/climate_indices\"\n",
    "if climate_indices_home_path not in sys.path:\n",
    "    sys.path.append(climate_indices_home_path)\n",
    "from climate_indices import compute, indices, utils\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Open the precipitation dataset as an xarray Dataset object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): absolute, machine-specific path -- point this at the local copy\n",
    "# of the low-resolution nClimGrid precipitation file before running\n",
    "ds_prcp = xr.open_dataset(\"/home/james/data/nclimgrid/nclimgrid_lowres_prcp.nc\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get the precipitation variable and reorder its dimensions to (lat, lon, time), so that time is the inner-most axis:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# with (lat, lon, time) ordering, indexing with [lat_index, lon_index]\n",
    "# selects the complete time series of a single grid cell\n",
    "da_prcp = ds_prcp['prcp'].transpose('lat', 'lon', 'time')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a NumPy array backed by shared memory and copy the precipitation data into it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Allocate a shared memory block sized for the precipitation values, wrap it\n",
    "# in a NumPy array, fill that array with the precipitation data, and finally\n",
    "# make it the array that backs the DataArray.\n",
    "# NOTE(review): nothing in this notebook calls shm_prcp.close()/unlink(), and\n",
    "# every re-run of this cell allocates a new block\n",
    "shm_prcp = shared_memory.SharedMemory(create=True, size=da_prcp.data.nbytes)\n",
    "shared_prcp = np.ndarray(shape=da_prcp.shape, dtype=da_prcp.dtype, buffer=shm_prcp.buf)\n",
    "np.copyto(shared_prcp, da_prcp.values)\n",
    "da_prcp.data = shared_prcp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'psm_44add675'"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# keep the name of the shared memory block (displayed as the cell's output)\n",
    "shm_name_prcp = shm_prcp.name\n",
    "shm_name_prcp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "# time scales to process; the remaining scales (2, 3, 6, 9, 12, 24) are disabled for now\n",
    "scales = [1]\n",
    "periodicity = compute.Periodicity.monthly\n",
    "period_times = 12\n",
    "\n",
    "# calibration period (first and last year)\n",
    "calibration_year_initial, calibration_year_final = 1900, 2000\n",
    "\n",
    "# year of the first time step of the precipitation record\n",
    "initial_year = int(da_prcp.time[0].dt.year)\n",
    "\n",
    "# grid dimensions and the resulting shape of the fitting parameter arrays\n",
    "total_lats, total_lons = da_prcp.shape[:2]\n",
    "fitting_shape = (total_lats, total_lons, period_times)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define a function that can be used to compute the gamma fitting parameters for a particular month scale:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_gammas(\n",
    "    da_precip: xr.DataArray,\n",
    "    scale: int,\n",
    "    calibration_year_initial,\n",
    "    calibration_year_final,\n",
    "    periodicity: compute.Periodicity,\n",
    ") -> (xr.DataArray, xr.DataArray):\n",
    "    \n",
    "    initial_year = int(da_precip['time'][0].dt.year)\n",
    "    if periodicity == compute.Periodicity.monthly:\n",
    "        period_times = 12\n",
    "        gamma_time_coord = \"month\"\n",
    "    elif periodicity == compute.Periodicity.daily:\n",
    "        period_times = 366\n",
    "        gamma_time_coord = \"day\"\n",
    "    gamma_coords={\"lat\": ds.lat, \"lon\": ds.lon, gamma_time_coord: range(period_times)}\n",
    "    total_lats = da_precip.shape[0]\n",
    "    total_lons = da_precip.shape[1]\n",
    "    fitting_shape = (total_lats, total_lons, period_times)\n",
    "    \n",
    "    # create a shared memory array for precipitation, copy the precipitation data\n",
    "    # into it, then replace the original underlying array with the shared memory\n",
    "    shm_prcp = shared_memory.SharedMemory(create=True, size=da_prcp.data.nbytes)\n",
    "    shared_prcp = np.ndarray(da_precip.shape, dtype=da_precip.dtype, buffer=shm_prcp.buf)\n",
    "    shared_prcp[:,:,:] = da_precip[:,:,:]\n",
    "    shm_name_prcp = shm_prcp.name\n",
    "\n",
    "    # create a shared memory array for the gamma distribution alpha and beta fitting parameters\n",
    "    shm_alpha = shared_memory.SharedMemory(create=True)  # , size=da_prcp.data.nbytes)\n",
    "    shm_name_alpha = shm_alpha.name\n",
    "    alphas = np.full(shape=fitting_shape, fill_value=np.NaN, buffer=shm_alpha.buf)\n",
    "    shm_beta = shared_memory.SharedMemory(create=True)  # , size=da_prcp.data.nbytes)\n",
    "    shm_name_beta = shm_beta.name\n",
    "    betas = np.full(shape=fitting_shape, fill_value=np.NaN, buffer=shm_beta.buf)\n",
    "\n",
    "    # loop over the grid cells and add a new arguments element for each\n",
    "                resize_arguments = {\n",
    "                \"file_id\": file_id,\n",
    "                \"image_ext\": image_ext,\n",
    "                \"annotation_format\": annotation_format,\n",
    "                \"annotation_ext\": annotation_ext,\n",
    "                \"input_images_dir\": input_images_dir,\n",
    "                \"input_annotations_dir\": input_annotations_dir,\n",
    "                \"output_images_dir\": output_images_dir,\n",
    "                \"output_annotations_dir\": output_annotations_dir,\n",
    "                \"new_width\": new_width,\n",
    "                \"new_height\": new_height,\n",
    "            }\n",
    "            resize_arguments_list.append(resize_arguments)\n",
    "\n",
    "    arguments_list = []\n",
    "    for lat_index in range(total_lats):\n",
    "        for lon_index in range(total_lons):\n",
    "\n",
    "            arguments = {\n",
    "                'lat_index': lat_index,\n",
    "                'lon_index': lon_index,\n",
    "                'shm_name_prcp': shm_name_prcp,\n",
    "                'shm_name_alpha': shm_name_alpha,\n",
    "                'shm_name_beta': shm_name_beta,\n",
    "                'scale': scale,\n",
    "                'calibration_year_initial': calibration_year_initial,\n",
    "                'calibration_year_final': calibration_year_final,\n",
    "                'periodicity': periodicity,\n",
    "            }\n",
    "            arguments_list.append(arguments)\n",
    "\n",
    "    # use a ProcessPoolExecutor to download the images in parallel\n",
    "    with concurrent.futures.ProcessPoolExecutor() as executor:\n",
    "\n",
    "        _logger.info(\"Computing gamma fitting parameters\")\n",
    "\n",
    "        # use the executor to map the download function to the iterable of arguments\n",
    "        list(tqdm(executor.map(compute_gammas_gridcell, arguments_list), total=len(arguments_list)))\n",
    "        \n",
    "    alpha_attrs = {\n",
    "        'description': 'shape parameter of the gamma distribution (also referred to as the concentration) ' + \\\n",
    "        f'computed from the {scale}-month scaled precipitation values',\n",
    "    }\n",
    "    da_alpha = xr.DataArray(\n",
    "        data=alphas,\n",
    "        coords=gamma_coords,\n",
    "        dims=tuple(gamma_coords.keys()),\n",
    "        name=f\"alpha_{str(scale).zfill(2)}\",\n",
    "        attrs=alpha_attrs,\n",
    "    )\n",
    "    beta_attrs = {\n",
    "        'description': '1 / scale of the distribution (also referred to as the rate) ' + \\\n",
    "        f'computed from the {scale}-month scaled precipitation values',\n",
    "    }\n",
    "    da_beta = xr.DataArray(\n",
    "        data=betas,\n",
    "        coords=gamma_coords,\n",
    "        dims=tuple(gamma_coords.keys()),\n",
    "        name=f\"beta_{str(scale).zfill(2)}\",\n",
    "        attrs=beta_attrs,\n",
    "    )\n",
    "\n",
    "    return da_alpha, da_beta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_gammas_gridcell(\n",
    "    arguments: Dict,\n",
    "):\n",
    "    \n",
    "    initial_year = int(da_precip['time'][0].dt.year)\n",
    "    if periodicity == compute.Periodicity.monthly:\n",
    "        period_times = 12\n",
    "        gamma_time_coord = \"month\"\n",
    "    elif periodicity == compute.Periodicity.daily:\n",
    "        period_times = 366\n",
    "        gamma_time_coord = \"day\"\n",
    "    gamma_coords={\"lat\": ds.lat, \"lon\": ds.lon, gamma_time_coord: range(period_times)}\n",
    "    total_lats = da_precip.shape[0]\n",
    "    total_lons = da_precip.shape[1]\n",
    "    fitting_shape = (total_lats, total_lons, period_times)\n",
    "    \n",
    "    # create a shared memory array for precipitation, copy the precipitation data\n",
    "    # into it, then replace the original underlying array with the shared memory\n",
    "    shm_prcp = shared_memory.SharedMemory(create=True, size=da_prcp.data.nbytes)\n",
    "    shared_prcp = np.ndarray(da_precip.shape, dtype=da_precip.dtype, buffer=shm_prcp.buf)\n",
    "    shared_prcp[:,:,:] = da_precip[:,:,:]\n",
    "    shm_name_prcp = shm_prcp.name\n",
    "\n",
    "    shm_alpha = shared_memory.SharedMemory(create=True)  # , size=da_prcp.data.nbytes)\n",
    "    shm_name_alpha = shm_alpha.name\n",
    "    alphas = np.full(shape=fitting_shape, fill_value=np.NaN, buffer=shm_alpha.buf)\n",
    "    shm_beta = shared_memory.SharedMemory(create=True)  # , size=da_prcp.data.nbytes)\n",
    "    shm_name_beta = shm_beta.name\n",
    "    betas = np.full(shape=fitting_shape, fill_value=np.NaN, buffer=shm_beta.buf)\n",
    "\n",
    "    # loop over the grid cells and add a new arguments element for each\n",
    "                resize_arguments = {\n",
    "                \"file_id\": file_id,\n",
    "                \"image_ext\": image_ext,\n",
    "                \"annotation_format\": annotation_format,\n",
    "                \"annotation_ext\": annotation_ext,\n",
    "                \"input_images_dir\": input_images_dir,\n",
    "                \"input_annotations_dir\": input_annotations_dir,\n",
    "                \"output_images_dir\": output_images_dir,\n",
    "                \"output_annotations_dir\": output_annotations_dir,\n",
    "                \"new_width\": new_width,\n",
    "                \"new_height\": new_height,\n",
    "            }\n",
    "            resize_arguments_list.append(resize_arguments)\n",
    "\n",
    "    arguments_list = []\n",
    "    for lat_index in range(total_lats):\n",
    "        for lon_index in range(total_lons):\n",
    "\n",
    "            arguments = {\n",
    "                'lat_index': lat_index,\n",
    "                'lon_index': lon_index,\n",
    "                'shm_name_prcp': shm_name_prcp,\n",
    "                'shm_name_alpha': shm_name_alpha,\n",
    "                'shm_name_beta': shm_name_beta,\n",
    "                'scale': scale,\n",
    "                'calibration_year_initial': calibration_year_initial,\n",
    "                'calibration_year_final': calibration_year_final,\n",
    "                'periodicity': periodicity,\n",
    "            }\n",
    "            # get the precipitation values for the lat/lon grid cell\n",
    "            values = da_precip[lat_index, lon_index]\n",
    "\n",
    "            # skip over this grid cell if all NaN values\n",
    "            if (np.ma.is_masked(values) and values.mask.all()) or np.all(np.isnan(values)):\n",
    "                continue\n",
    "\n",
    "            # convolve to scale\n",
    "            scaled_values = \\\n",
    "                compute.scale_values(\n",
    "                    values,\n",
    "                    scale=scale,\n",
    "                    periodicity=periodicity,\n",
    "                )\n",
    "\n",
    "            # compute the fitting parameters on the scaled data\n",
    "            alphas[lat_index, lon_index], betas[lat_index, lon_index] = \\\n",
    "                compute.gamma_parameters(\n",
    "                    scaled_values,\n",
    "                    data_start_year=initial_year,\n",
    "                    calibration_start_year=calibration_year_initial,\n",
    "                    calibration_end_year=calibration_year_final,\n",
    "                    periodicity=periodicity,\n",
    "                )\n",
    "\n",
    "    # use a ProcessPoolExecutor to download the images in parallel\n",
    "    with concurrent.futures.ProcessPoolExecutor() as executor:\n",
    "\n",
    "        _logger.info(\"Resizing files\")\n",
    "\n",
    "        # use the executor to map the download function to the iterable of arguments\n",
    "        list(tqdm(executor.map(resize_image, resize_arguments_list),\n",
    "                  total=len(resize_arguments_list)))\n",
    "        \n",
    "    alpha_attrs = {\n",
    "        'description': 'shape parameter of the gamma distribution (also referred to as the concentration) ' + \\\n",
    "        f'computed from the {scale}-month scaled precipitation values',\n",
    "    }\n",
    "    da_alpha = xr.DataArray(\n",
    "        data=alphas,\n",
    "        coords=gamma_coords,\n",
    "        dims=tuple(gamma_coords.keys()),\n",
    "        name=f\"alpha_{str(scale).zfill(2)}\",\n",
    "        attrs=alpha_attrs,\n",
    "    )\n",
    "    beta_attrs = {\n",
    "        'description': '1 / scale of the distribution (also referred to as the rate) ' + \\\n",
    "        f'computed from the {scale}-month scaled precipitation values',\n",
    "    }\n",
    "    da_beta = xr.DataArray(\n",
    "        data=betas,\n",
    "        coords=gamma_coords,\n",
    "        dims=tuple(gamma_coords.keys()),\n",
    "        name=f\"beta_{str(scale).zfill(2)}\",\n",
    "        attrs=beta_attrs,\n",
    "    )\n",
    "\n",
    "    return da_alpha, da_beta"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define a function that can be used to compute the SPI for a particular month scale:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_spi_gamma(\n",
    "        da_precip: xr.DataArray,\n",
    "        da_alpha: xr.DataArray,\n",
    "        da_beta: xr.DataArray,\n",
    "        scale: int,\n",
    "        periodicity: compute.Periodicity,\n",
    ") -> xr.DataArray:\n",
    "    \"\"\"\n",
    "    Compute SPI (gamma) for every grid cell from pre-computed fitting parameters.\n",
    "\n",
    "    NOTE: the calibration period is not a parameter; it is read from the\n",
    "    notebook-level variables calibration_year_initial and calibration_year_final.\n",
    "\n",
    "    :param da_precip: precipitation with dimensions (lat, lon, time)\n",
    "    :param da_alpha: gamma alpha fitting parameters, indexed by (lat, lon, ...)\n",
    "    :param da_beta: gamma beta fitting parameters, indexed by (lat, lon, ...)\n",
    "    :param scale: scale passed to indices.spi()\n",
    "    :param periodicity: periodicity of the precipitation data, passed to indices.spi()\n",
    "    :return: SPI DataArray with the dimensions and coordinates of da_precip;\n",
    "        grid cells with only NaN precipitation values are left as NaN\n",
    "    \"\"\"\n",
    "\n",
    "    initial_year = int(da_precip['time'][0].dt.year)\n",
    "\n",
    "    # pull out the underlying NumPy arrays once, rather than\n",
    "    # indexing the xarray objects for every single grid cell\n",
    "    prcp = da_precip.values\n",
    "    alphas = da_alpha.values\n",
    "    betas = da_beta.values\n",
    "\n",
    "    total_lats = prcp.shape[0]\n",
    "    total_lons = prcp.shape[1]\n",
    "    spi = np.full(shape=prcp.shape, fill_value=np.nan)\n",
    "\n",
    "    for lat_index in range(total_lats):\n",
    "        for lon_index in range(total_lons):\n",
    "\n",
    "            # get the values for the lat/lon grid cell\n",
    "            values = prcp[lat_index, lon_index]\n",
    "\n",
    "            # skip over this grid cell if all NaN values\n",
    "            if np.all(np.isnan(values)):\n",
    "                continue\n",
    "\n",
    "            gamma_parameters = {\n",
    "                \"alphas\": alphas[lat_index, lon_index],\n",
    "                \"betas\": betas[lat_index, lon_index],\n",
    "            }\n",
    "\n",
    "            # compute the SPI, honoring the periodicity we were given\n",
    "            spi[lat_index, lon_index] = \\\n",
    "                indices.spi(\n",
    "                    values,\n",
    "                    scale=scale,\n",
    "                    distribution=indices.Distribution.gamma,\n",
    "                    data_start_year=initial_year,\n",
    "                    calibration_year_initial=calibration_year_initial,\n",
    "                    calibration_year_final=calibration_year_final,\n",
    "                    periodicity=periodicity,\n",
    "                    fitting_params=gamma_parameters,\n",
    "                )\n",
    "\n",
    "    # build a DataArray for this scale's SPI\n",
    "    da_spi = xr.DataArray(\n",
    "        data=spi,\n",
    "        coords=da_precip.coords,\n",
    "        dims=(\"lat\", \"lon\", \"time\"),\n",
    "        name=f\"spi_gamma_{str(scale).zfill(2)}\",\n",
    "    )\n",
    "\n",
    "    # use the plain time values in the description (formatting the\n",
    "    # coordinate DataArrays would embed their full repr in the text)\n",
    "    time_values = da_precip['time'].values\n",
    "    time_initial = str(time_values[0])\n",
    "    time_final = str(time_values[-1])\n",
    "    da_spi.attrs = {\n",
    "        'description': f'SPI ({scale}-{periodicity} gamma) computed from monthly precipitation ' + \\\n",
    "            f'data for the period {time_initial} through {time_final} using a ' + \\\n",
    "            f'calibration period from {calibration_year_initial} through {calibration_year_final}',\n",
    "        'valid_min': -3.09,\n",
    "        'valid_max': 3.09,\n",
    "        'long_name': f'{scale}-{periodicity} SPI(gamma)',\n",
    "        'calibration_year_initial': calibration_year_initial,\n",
    "        'calibration_year_final': calibration_year_final,\n",
    "    }\n",
    "\n",
    "    return da_spi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copy the attributes from the precipitation dataset that will be applicable to the corresponding gamma fitting parameters and SPI datasets:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "# global attributes of the precipitation dataset that are carried over\n",
    "# to the gamma fitting parameter and SPI datasets\n",
    "attrs_to_copy = [\n",
    "    'Conventions',\n",
    "    'ncei_template_version',\n",
    "    'naming_authority',\n",
    "    'standard_name_vocabulary',\n",
    "    'institution',\n",
    "    'geospatial_lat_min',\n",
    "    'geospatial_lat_max',\n",
    "    'geospatial_lon_min',\n",
    "    'geospatial_lon_max',\n",
    "    'geospatial_lat_units',\n",
    "    'geospatial_lon_units',\n",
    "]\n",
    "global_attrs = {key: value for (key, value) in ds_prcp.attrs.items() if key in attrs_to_copy}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute the gamma fitting parameters for all scales and add these into a Dataset that we'll write to NetCDF:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# number of fitted time steps per year, and the name of the corresponding coordinate\n",
    "if periodicity == compute.Periodicity.monthly:\n",
    "    period_times = 12\n",
    "    gamma_time_coord = \"month\"\n",
    "elif periodicity == compute.Periodicity.daily:\n",
    "    period_times = 366\n",
    "    gamma_time_coord = \"day\"\n",
    "else:\n",
    "    raise ValueError(f\"Unsupported periodicity: {periodicity}\")\n",
    "\n",
    "ds_gamma = xr.Dataset(\n",
    "    coords={\"lat\": ds_prcp.lat, \"lon\": ds_prcp.lon, gamma_time_coord: range(period_times)},\n",
    "    attrs=global_attrs,\n",
    ")\n",
    "\n",
    "# add an alpha and a beta variable for each scale\n",
    "for scale in scales:\n",
    "    var_name_alpha = f\"alpha_{str(scale).zfill(2)}\"\n",
    "    var_name_beta = f\"beta_{str(scale).zfill(2)}\"\n",
    "    da_alpha, da_beta = compute_gammas(\n",
    "        da_prcp,\n",
    "        scale,\n",
    "        calibration_year_initial,\n",
    "        calibration_year_final,\n",
    "        periodicity,\n",
    "    )\n",
    "    ds_gamma[var_name_alpha] = da_alpha\n",
    "    ds_gamma[var_name_beta] = da_beta\n",
    "\n",
    "netcdf_gamma = '/home/james/data/nclimgrid/nclimgrid_lowres_gamma.nc'\n",
    "ds_gamma.to_netcdf(netcdf_gamma)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute the SPI using the pre-computed gamma fitting parameters for all scales and add these into an SPI(gamma) Dataset that we'll write to NetCDF:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "ds_spi = xr.Dataset(\n",
    "    coords=ds_prcp.coords,\n",
    "    attrs=global_attrs,\n",
    ")\n",
    "\n",
    "# add an SPI variable for each scale, computed from\n",
    "# that scale's alpha and beta fitting parameters\n",
    "for scale in scales:\n",
    "    var_name_alpha = f\"alpha_{str(scale).zfill(2)}\"\n",
    "    var_name_beta = f\"beta_{str(scale).zfill(2)}\"\n",
    "    da_spi = compute_spi_gamma(\n",
    "        da_prcp,\n",
    "        ds_gamma[var_name_alpha],\n",
    "        ds_gamma[var_name_beta],\n",
    "        scale,\n",
    "        periodicity,\n",
    "    )\n",
    "    ds_spi[f\"spi_gamma_{str(scale).zfill(2)}\"] = da_spi\n",
    "\n",
    "netcdf_spi = '/home/james/data/nclimgrid/nclimgrid_lowres_spi_gamma.nc'\n",
    "ds_spi.to_netcdf(netcdf_spi)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plot a time step to validate that the SPI values look reasonable:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot the first of the computed scales (a fixed name such as\n",
    "# \"spi_gamma_03\" only exists when that scale is listed in `scales`)\n",
    "ds_spi[f\"spi_gamma_{str(scales[0]).zfill(2)}\"].isel(time=500).plot()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
